{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_uuid": "4e128e4337fd5c906540c112bc1d4e0fd2f38ef3"
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "import os\n",
    "import pickle\n",
    "import re\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from pytorch_pretrained_bert.modeling import BertModel\n",
    "from pytorch_pretrained_bert.tokenization import BertTokenizer, BasicTokenizer\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "from helperbot import BaseBot, TriangularLR\n",
    "\n",
    "\n",
    "'''\n",
    "Forked and edited from:\n",
    "https://www.kaggle.com/ceshine/pytorch-bert-baseline-public-score-0-54\n",
    "\n",
    "We use this notebook to generate BERT embeddings for two mentions and the gender pronoun.\n",
    "We do not remove punctuation during data pre-processing\n",
    "\n",
    "This part can also be used as base deep learning model\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# gap-development.tsv is loaded as the test frame; gap-test.tsv and\n",
    "# gap-validation.tsv are stacked row-wise into one train/validation frame.\n",
    "df_test = pd.read_csv(\"gap-development.tsv\", sep=\"\\t\")\n",
    "train_val_parts = [\n",
    "    pd.read_csv(tsv_name, sep=\"\\t\")\n",
    "    for tsv_name in (\"gap-test.tsv\", \"gap-validation.tsv\")\n",
    "]\n",
    "df_train_val = pd.concat(train_val_parts, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "We modify the output of Head Model, so that it will only output extracted bert embeddings\n",
    "'''\n",
    "\n",
    "class Head(nn.Module):\n",
    "    \"\"\"Selects the BERT vectors at the A, B and pronoun token positions.\n",
    "\n",
    "    An MLP classifier (``self.fc``) is still built and initialised, but\n",
    "    ``forward`` bypasses it and returns the gathered embeddings only.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, bert_hidden_size: int):\n",
    "        \"\"\"Build and initialise the (currently unused) classifier.\n",
    "\n",
    "        Args:\n",
    "            bert_hidden_size: size of the last dimension of the BERT output.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        self.bert_hidden_size = bert_hidden_size\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.BatchNorm1d(bert_hidden_size * 3),\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(bert_hidden_size * 3, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.BatchNorm1d(512),\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(512, 3)\n",
    "        )\n",
    "        for module in self.fc:\n",
    "            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):\n",
    "                nn.init.constant_(module.weight, 1)\n",
    "                nn.init.constant_(module.bias, 0)\n",
    "                print(\"Initing batchnorm\")\n",
    "            elif isinstance(module, nn.Linear):\n",
    "                if getattr(module, \"weight_v\", None) is not None:\n",
    "                    # Weight-normalised layer: magnitude and direction are separate parameters.\n",
    "                    nn.init.uniform_(module.weight_g, 0, 1)\n",
    "                    nn.init.kaiming_normal_(module.weight_v)\n",
    "                    print(\"Initing linear with weight normalization\")\n",
    "                    # Check the module itself; there is no `model` name in scope here.\n",
    "                    assert module.weight_g is not None\n",
    "                else:\n",
    "                    nn.init.kaiming_normal_(module.weight)\n",
    "                    print(\"Initing linear\")\n",
    "                nn.init.constant_(module.bias, 0)\n",
    "\n",
    "    def forward(self, bert_outputs, offsets):\n",
    "        \"\"\"Gather the three target-token embeddings of every example.\n",
    "\n",
    "        Args:\n",
    "            bert_outputs: tensor whose third dimension equals ``bert_hidden_size``.\n",
    "            offsets: integer tensor indexing dimension 1 of ``bert_outputs``;\n",
    "                presumably (batch, 3) for A, B and the pronoun -- confirm with caller.\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (batch, 3, bert_hidden_size). The classifier in\n",
    "            ``self.fc`` is intentionally not applied.\n",
    "        \"\"\"\n",
    "        assert bert_outputs.size(2) == self.bert_hidden_size\n",
    "\n",
    "        # Repeat each offset across the hidden dimension so gather picks whole vectors.\n",
    "        index = offsets.unsqueeze(2).expand(-1, -1, bert_outputs.size(2))\n",
    "        extracted_outputs = bert_outputs.gather(1, index).view(\n",
    "            bert_outputs.size(0), 3, -1)\n",
    "\n",
    "        # Only the extracted BERT embeddings are returned; self.fc is skipped.\n",
    "        return extracted_outputs\n",
    "\n",
    "    \n",
    "class GAPModel(nn.Module):\n",
    "    \"\"\"Pretrained BERT encoder followed by the Head embedding extractor.\"\"\"\n",
    "\n",
    "    def __init__(self, bert_model: str, device: torch.device):\n",
    "        \"\"\"Load the pretrained encoder and build the head on `device`.\n",
    "\n",
    "        Args:\n",
    "            bert_model: name or path passed to ``BertModel.from_pretrained``.\n",
    "            device: device both sub-modules are moved to and inputs are sent to.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.device = device\n",
    "        self.bert = BertModel.from_pretrained(bert_model).to(device)\n",
    "        # Read the width from the loaded checkpoint instead of hard-coding 1024,\n",
    "        # so checkpoints with a different hidden size pass Head's shape assert.\n",
    "        self.bert_hidden_size = self.bert.config.hidden_size\n",
    "        self.head = Head(self.bert_hidden_size).to(device)\n",
    "\n",
    "    def forward(self, token_tensor, offsets):\n",
    "        \"\"\"Encode the tokens and return the embeddings selected by `offsets`.\n",
    "\n",
    "        Args:\n",
    "            token_tensor: integer token ids; ids equal to 0 are masked out of\n",
    "                attention (they are the padding written by the collate step).\n",
    "            offsets: token positions forwarded to the head.\n",
    "\n",
    "        Returns:\n",
    "            The output of ``self.head``, i.e. the gathered BERT embeddings.\n",
    "        \"\"\"\n",
    "        token_tensor = token_tensor.to(self.device)\n",
    "\n",
    "        bert_outputs, _ = self.bert(\n",
    "            token_tensor,\n",
    "            attention_mask=(token_tensor > 0).long(),\n",
    "            token_type_ids=None,\n",
    "            output_all_encoded_layers=False)\n",
    "        # Only the BERT embeddings picked out by the head are returned.\n",
    "        head_outputs = self.head(bert_outputs, offsets.to(self.device))\n",
    "        return head_outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_uuid": "d58d18c34f5df9ec8f8d8fb048ae6c10fbf9914a"
   },
   "outputs": [],
   "source": [
    "BERT_MODEL = 'bert-large-uncased'\n",
    "# NOTE: despite its name, CASED is passed to the tokenizer as `do_lower_case`,\n",
    "# so True means the input text IS lower-cased (matching the uncased checkpoint above).\n",
    "CASED = True\n",
    "\n",
    "\n",
    "def insert_tag(row):\n",
    "    \"\"\"Return row[\"Text\"] with marker tags spliced in at the A, B and pronoun offsets.\"\"\"\n",
    "    markers = [\n",
    "        (row[\"A-offset\"], \" [THISISA] \"),\n",
    "        (row[\"B-offset\"], \" [THISISB] \"),\n",
    "        (row[\"Pronoun-offset\"], \" [THISISP] \"),\n",
    "    ]\n",
    "    # Work from the rightmost offset to the leftmost so that an insertion\n",
    "    # never shifts the positions that are still to be processed.\n",
    "    markers.sort(key=lambda pair: pair[0], reverse=True)\n",
    "\n",
    "    tagged = row[\"Text\"]\n",
    "    for position, marker in markers:\n",
    "        tagged = tagged[:position] + marker + tagged[position:]\n",
    "    return tagged\n",
    "\n",
    "\n",
    "def tokenize(text, tokenizer):\n",
    "    \"\"\"Tokenize `text`, dropping the marker tags and recording where they stood.\n",
    "\n",
    "    Returns the remaining tokens and a tuple holding, for the A, B and\n",
    "    pronoun tags in that order, the index of the token that follows the tag.\n",
    "    Raises KeyError when one of the tags never shows up in the token stream.\n",
    "    \"\"\"\n",
    "    tag_names = (\"[THISISA]\", \"[THISISB]\", \"[THISISP]\")\n",
    "    tag_positions = {}\n",
    "    kept = []\n",
    "    for piece in tokenizer.tokenize(text):\n",
    "        if piece not in tag_names:\n",
    "            kept.append(piece)\n",
    "        else:\n",
    "            # The tag is removed, so it points at the next token that is kept.\n",
    "            tag_positions[piece] = len(kept)\n",
    "    return kept, tuple(tag_positions[name] for name in tag_names)\n",
    "\n",
    "\n",
    "class GAPDataset(Dataset):\n",
    "    \"\"\"Token ids of every example plus the token positions of A, B and the pronoun.\"\"\"\n",
    "\n",
    "    def __init__(self, df, tokenizer, labeled=True):\n",
    "        self.labeled = labeled\n",
    "        if labeled:\n",
    "            # Three boolean targets per row: A-coref, B-coref and \"Neither\".\n",
    "            targets = df[[\"A-coref\", \"B-coref\"]].copy()\n",
    "            targets[\"Neither\"] = ~(df[\"A-coref\"] | df[\"B-coref\"])\n",
    "            self.y = targets.values.astype(\"bool\")\n",
    "\n",
    "        # Tokens and offsets (positions of A, B, and P) for every row.\n",
    "        self.offsets = []\n",
    "        self.tokens = []\n",
    "        for _, example in df.iterrows():\n",
    "            pieces, positions = tokenize(insert_tag(example), tokenizer)\n",
    "            self.offsets.append(positions)\n",
    "            wrapped = [\"[CLS]\"] + pieces + [\"[SEP]\"]\n",
    "            self.tokens.append(tokenizer.convert_tokens_to_ids(wrapped))\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.tokens)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # Unlabeled datasets report None in the label slot.\n",
    "        target = self.y[idx] if self.labeled else None\n",
    "        return self.tokens[idx], self.offsets[idx], target\n",
    "\n",
    "    \n",
    "    \n",
    "def collate_examples(batch, truncate_len=500):\n",
    "    \"\"\"Turn a list of dataset items into padded tensors.\n",
    "\n",
    "    Args:\n",
    "        batch: sequence of (token_ids, offsets, label) items; ``label`` is a\n",
    "            one-hot array, or None for unlabeled data.\n",
    "        truncate_len: maximum number of token ids kept per example.\n",
    "\n",
    "    Returns:\n",
    "        (token_tensor, offsets, labels): zero-padded int64 token ids, the\n",
    "        offsets shifted by one for the leading [CLS] token, and the class\n",
    "        index per example -- or None when the batch carries no labels.\n",
    "    \"\"\"\n",
    "    transposed = list(zip(*batch))\n",
    "\n",
    "    # Pad with zeros up to the longest example, capped at truncate_len.\n",
    "    max_len = min(max(len(ids) for ids in transposed[0]), truncate_len)\n",
    "    tokens = np.zeros((len(batch), max_len), dtype=np.int64)\n",
    "    for i, ids in enumerate(transposed[0]):\n",
    "        ids = np.array(ids[:truncate_len])\n",
    "        tokens[i, :len(ids)] = ids\n",
    "    token_tensor = torch.from_numpy(tokens)\n",
    "\n",
    "    offsets = torch.stack([torch.LongTensor(x) for x in transposed[1]], dim=0) + 1 # Account for the [CLS] token\n",
    "\n",
    "    # Unlabeled datasets yield None in the label slot; calling .astype on it\n",
    "    # would raise AttributeError, so hand back None labels instead.\n",
    "    if len(transposed) < 3 or any(y is None for y in transposed[2]):\n",
    "        return token_tensor, offsets, None\n",
    "\n",
    "    one_hot_labels = torch.stack([torch.from_numpy(y.astype(\"uint8\")) for y in transposed[2]], dim=0)\n",
    "    _, labels = one_hot_labels.max(dim=1)\n",
    "    return token_tensor, offsets, labels\n",
    "\n",
    "# Register the three marker tags: they are exempt from splitting and get a\n",
    "# placeholder id of -1 in the vocabulary.\n",
    "tokenizer = BertTokenizer.from_pretrained(\n",
    "    BERT_MODEL,\n",
    "    do_lower_case=CASED,\n",
    "    never_split=(\"[UNK]\", \"[SEP]\", \"[PAD]\", \"[CLS]\", \"[MASK]\", \"[THISISA]\", \"[THISISB]\", \"[THISISP]\"),\n",
    ")\n",
    "\n",
    "tokenizer.vocab.update({\"[THISISA]\": -1, \"[THISISB]\": -1, \"[THISISP]\": -1})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_uuid": "69689365738454b33649a14d83eea49cc1b18687"
   },
   "outputs": [],
   "source": [
    "train_val_ds = GAPDataset(df_train_val, tokenizer)\n",
    "test_ds = GAPDataset(df_test, tokenizer)\n",
    "\n",
    "# Both loaders share one configuration: a single example per batch, in the\n",
    "# original row order.\n",
    "loader_options = dict(collate_fn=collate_examples, batch_size=1, shuffle=False)\n",
    "\n",
    "train_loader = DataLoader(train_val_ds, **loader_options)\n",
    "test_loader = DataLoader(test_ds, **loader_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_uuid": "a85f87ed8d73b520e39a5dc07da1838867ac2653"
   },
   "outputs": [],
   "source": [
    "def children(m):\n",
    "    \"\"\"Return `m` itself if it is a list/tuple, otherwise a list of its direct child modules.\"\"\"\n",
    "    if isinstance(m, (list, tuple)):\n",
    "        return m\n",
    "    return list(m.children())\n",
    "def set_trainable_attr(m, b):\n",
    "    \"\"\"Record `b` on `m.trainable` and apply it as requires_grad to every parameter of `m`.\"\"\"\n",
    "    m.trainable = b\n",
    "    for param in m.parameters():\n",
    "        param.requires_grad = b\n",
    "def apply_leaf(m, f):\n",
    "    \"\"\"Call `f` on `m` when it is an nn.Module, then recurse into everything below it.\"\"\"\n",
    "    descendants = children(m)\n",
    "    if isinstance(m, nn.Module):\n",
    "        f(m)\n",
    "    for child in descendants:\n",
    "        apply_leaf(child, f)\n",
    "def set_trainable(l, b):\n",
    "    \"\"\"Set the trainable flag and requires_grad to `b` on `l` and all modules nested in it.\"\"\"\n",
    "    def _flag(module):\n",
    "        set_trainable_attr(module, b)\n",
    "\n",
    "    apply_leaf(l, _flag)\n",
    "\n",
    "# Prefer the second GPU as before, but fall back to the first GPU or the CPU\n",
    "# so the notebook still runs on hosts where \"cuda:1\" does not exist.\n",
    "if torch.cuda.device_count() > 1:\n",
    "    device = torch.device(\"cuda:1\")\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "model = GAPModel(BERT_MODEL, device)\n",
    "# BERT stays frozen; only the head's parameters are left trainable.\n",
    "set_trainable(model.bert, False)\n",
    "set_trainable(model.head, True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_outputs = []\n",
    "\n",
    "# Inference only: eval mode switches off dropout, no_grad skips autograd bookkeeping.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for token_tensor, offsets, labels in train_loader:\n",
    "        prediction = model(token_tensor, offsets)\n",
    "        # Move each result to host memory so GPU memory does not grow with the\n",
    "        # dataset and the saved list can be loaded on a machine without this GPU.\n",
    "        bert_outputs.append(prediction.cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write through a context manager so the file is flushed and closed.\n",
    "with open('others_bert_outputs.pkl', \"wb\") as out_file:\n",
    "    pickle.dump(bert_outputs, out_file)\n",
    "# For the test split, write to 'test_others_bert_outputs.pkl' instead."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
